% By Martin M. Andreasen, August 2009
% This function computes the conditional moments up to k_period periods into
% the future for the variables selected by "y_select" in the g-function.
% This implementation allows for non-zero third moments in the shock
% distributions.
% Note that y_select must have the form of 1 x dimy, where dimy is the
% number of variables in the g-functions we want to compute conditional
% moments for. The steady state value of y_select should be stored in y_ss
%
% In terms of the notation we solve: p(x,sig) = E_t[r(x_t+1,sig)]
% The computations are derived by using the Perturbation On Perturbation (POP) method.
% The integer "T_function" takes the following values:
% 1  - if a log-transformation of bond prices is used
% 0  - if no log-transformation of bond prices is used
%
function [px,pxx,pss,pxxx,pssx,psss] = CondMoments_Mom3(gx,gxx,gss,gxxx,gssx,gsss,...
    hx,hxx,hss,hxxx,hssx,hsss,eta,vectorMom3,y_select,y_ss,k_period,T_function)
% CondMoments_Mom3  Derivatives of the conditional moment
%   p(x,sig) = E_t[r(x_t+1,sig)] for horizons 1..k_period, computed
%   recursively with the Perturbation On Perturbation (POP) method.
%
% Inputs (shapes as implied by the indexing in this function):
%   gx,gxx,gss,gxxx,gssx,gsss : derivatives of the g-function; only row
%       y_select(1,i) of each is used for the i-th selected variable.
%   hx,hxx,hss,hxxx,hssx,hsss : derivatives of the state transition h;
%       nx = size(hx,1) is the number of states.
%   eta        : nx x ne shock loading matrix. Each column eta(:,phi) enters
%                the second-moment terms with unit weight and without cross
%                terms, i.e. the shocks are treated as mutually uncorrelated
%                with unit variance.
%   vectorMom3 : third moment of each shock (element phi for shock phi);
%                cross third moments are not used (taken to be zero).
%   y_select   : 1 x dimy row of indices of the variables to process.
%   y_ss       : steady-state values, read as y_ss(1,y_select(1,i)).
%   k_period   : number of periods ahead.
%   T_function : 1 -> log-transformation of bond prices (all derivatives of
%                     T are set equal to the steady-state value);
%                0 -> no transformation (Tp = 1, Tpp = Tppp = 0).
%
% Outputs: px (dimy x k_period x nx), pxx (dimy x k_period x nx x nx),
%   pss (dimy x k_period), pxxx (dimy x k_period x nx x nx x nx),
%   pssx (dimy x k_period x nx), psss (dimy x k_period). Entry (i,j,...) is
%   the j-period-ahead derivative for variable y_select(1,i). When dimy == 1
%   the leading singleton dimension is removed (see the end of the function).

if ~(isequal(T_function,0) || isequal(T_function,1))
    error('CondMoments_Mom3:badTFunction', ...
        'T_function must be 0 or 1.');
end

% Allocating memory
dimy    = size(y_select,2);
nx      = size(hx,1);
ne      = size(eta,2);
px      = zeros(dimy,k_period,nx);
pxx     = zeros(dimy,k_period,nx,nx);
pss     = zeros(dimy,k_period);
pxxx    = zeros(dimy,k_period,nx,nx,nx);
pssx    = zeros(dimy,k_period,nx);
psss    = zeros(dimy,k_period);

for i=1:dimy
    % Period-1 "r" is the selected row of g; for later periods r is replaced
    % by the p computed in the previous period (see the update at loop end).
    rx   = reshape(gx(y_select(1,i),:),1,nx);
    rxx  = reshape(gxx(y_select(1,i),:,:),nx,nx);
    rss  = gss(y_select(1,i),1);
    rxxx = reshape(gxxx(y_select(1,i),:,:,:),nx,nx,nx);
    rssx = reshape(gssx(y_select(1,i),:),1,nx);
    rsss = gsss(y_select(1,i),1);

    % Tp* multiply derivatives of p, Tr* multiply derivatives of r; the code
    % uses the same values for both.
    T = y_ss(1,y_select(1,i));
    if T_function == 1
        % Log-transformation is used of bond prices
        Tp   = T;
        Tpp  = T;
        Tppp = T;
    else
        % Log-transformation is NOT used of bond prices
        Tp   = 1;
        Tpp  = 0;
        Tppp = 0;
    end
    Tr   = Tp;
    Trr  = Tpp;
    Trrr = Tppp;

    for j=1:k_period
        % ************** The first order effects *****************
        px(i,j,:) = Tp*rx(1,:)*hx;
        px(i,j,:) = px(i,j,:)/Tp;

        %*************** The second order effects ***************
        % Only alfa2 >= alfa1 is computed; the lower triangle is mirrored.
        for alfa1=1:nx
            for alfa2=alfa1:nx
                pxx(i,j,alfa1,alfa2) = -Tpp*px(i,j,alfa1)*px(i,j,alfa2) ...
                    + Trr*rx(1,:)*hx(:,alfa1)*rx(1,:)*hx(:,alfa2) ...
                    + Tr*hx(:,alfa1)'*rxx*hx(:,alfa2) ...
                    + Tr*rx(1,:)*squeeze(hxx(:,alfa1,alfa2));
            end
            if alfa1 > 1
                pxx(i,j,alfa1,1:alfa1-1) = squeeze(pxx(i,j,1:alfa1-1,alfa1))';
            end
        end
        pxx(i,j,:,:) = pxx(i,j,:,:)/Tp;

        % pss: second sigma-derivative; each shock contributes with unit weight.
        pss(i,j) = Tr*rx(1,:)*hss + Tr*rss;
        for phi2=1:ne
            eta_phi  = eta(:,phi2);
            pss(i,j) = pss(i,j) + Trr*(rx(1,:)*eta_phi)^2 ...
                + Tr*eta_phi'*rxx*eta_phi;
        end
        pss(i,j) = pss(i,j)/Tp;

        % ***************** The third order effects *****************
        % Only alfa1 <= alfa2 <= alfa3 is computed; the remaining
        % permutations are filled by symmetry below.
        for alfa1=1:nx
            for alfa2=alfa1:nx
                for alfa3=alfa2:nx
                    tmp = 0;
                    for gama3=1:nx
                        tmp = tmp + Tr*hx(:,alfa1)'*squeeze(rxxx(:,:,gama3))*hx(:,alfa2)*hx(gama3,alfa3);
                    end
                    pxxx(i,j,alfa1,alfa2,alfa3) = -Tppp*px(i,j,alfa1)*px(i,j,alfa2)*px(i,j,alfa3)...
                        -Tpp*pxx(i,j,alfa1,alfa3)*px(i,j,alfa2)...
                        -Tpp*px(i,j,alfa1)*pxx(i,j,alfa2,alfa3)...
                        -Tpp*pxx(i,j,alfa1,alfa2)*px(i,j,alfa3)...
                        +Trrr*rx(1,:)*hx(:,alfa1)*rx(1,:)*hx(:,alfa2)*rx(1,:)*hx(:,alfa3)...
                        +Trr*hx(:,alfa1)'*rxx*hx(:,alfa3)*rx(1,:)*hx(:,alfa2)...
                        +Trr*rx(1,:)*squeeze(hxx(:,alfa1,alfa3))*rx(1,:)*hx(:,alfa2)...
                        +Trr*rx(1,:)*hx(:,alfa1)*hx(:,alfa2)'*rxx*hx(:,alfa3)...
                        +Trr*rx(1,:)*hx(:,alfa1)*rx(1,:)*squeeze(hxx(:,alfa2,alfa3))...
                        +Trr*rx(1,:)*hx(:,alfa3)*hx(:,alfa1)'*rxx*hx(:,alfa2)...
                        +tmp...
                        +Tr*hx(:,alfa1)'*rxx*reshape(hxx(:,alfa2,alfa3),nx,1)...
                        +Tr*reshape(hxx(:,alfa1,alfa3),1,nx)*rxx*hx(:,alfa2)...
                        +Trr*rx(1,:)*hx(:,alfa3)*rx(1,:)*squeeze(hxx(:,alfa1,alfa2))...
                        +Tr*reshape(hxx(:,alfa1,alfa2),1,nx)*rxx*hx(:,alfa3)...
                        +Tr*rx(1,:)*reshape(hxxx(:,alfa1,alfa2,alfa3),nx,1);

                    % Using symmetry for alfa1 and alfa2 (alfa1==alfa2~=alfa3)
                    if alfa1 == alfa2 && alfa2 ~= alfa3
                        pxxx(i,j,alfa1,alfa3,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3);
                        pxxx(i,j,alfa3,alfa1,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3);
                    end
                    % Using symmetry for alfa2 and alfa3 (alfa1~=alfa2==alfa3)
                    if alfa1 ~= alfa2 && alfa2 == alfa3
                        pxxx(i,j,alfa2,alfa1,alfa2) = pxxx(i,j,alfa1,alfa2,alfa3);
                        pxxx(i,j,alfa2,alfa2,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3);
                    end
                    % Using symmetry for all three distinct indices
                    if alfa1 ~= alfa2 && alfa1 ~= alfa3 && alfa2 ~= alfa3
                        pxxx(i,j,alfa1,alfa3,alfa2) = pxxx(i,j,alfa1,alfa2,alfa3);
                        pxxx(i,j,alfa3,alfa1,alfa2) = pxxx(i,j,alfa1,alfa2,alfa3);
                        pxxx(i,j,alfa3,alfa2,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3);
                        pxxx(i,j,alfa2,alfa3,alfa1) = pxxx(i,j,alfa1,alfa2,alfa3);
                        pxxx(i,j,alfa2,alfa1,alfa3) = pxxx(i,j,alfa1,alfa2,alfa3);
                    end
                end
            end
        end
        pxxx(i,j,:,:,:) = pxxx(i,j,:,:,:)/Tp;

        % pssx: x-derivative of the second sigma-derivative. Per shock phi the
        % expectation contributes
        %   Trrr*(rx*hx)*(rx*eta_phi)^2
        %   + 2*Trr*(rx*eta_phi)*(eta_phi'*rxx*hx)   [d/dx of (rx*eta_phi)^2]
        %   + Trr*(rx*hx)*(eta_phi'*rxx*eta_phi)     [T'' times d/dx of r]
        %   + Tr*sum_g (eta_phi'*rxxx(:,:,g)*eta_phi)*hx(g,:)
        % The two Trr terms are distinct in general; they only coincide (and
        % could be merged into a single 3*Trr term) when nx = ne = 1.
        tmp = zeros(1,nx);
        for phi2=1:ne
            eta_phi = eta(:,phi2);
            rx_eta  = rx(1,:)*eta_phi;
            tmp = tmp + Trrr*rx(1,:)*hx*rx_eta^2 ...
                      + 2*Trr*rx_eta*(eta_phi'*rxx*hx) ...
                      + Trr*rx(1,:)*hx*(eta_phi'*rxx*eta_phi);
            for gama3=1:nx
                tmp = tmp + Tr*(eta_phi'*rxxx(:,:,gama3)*eta_phi)*hx(gama3,:);
            end
        end
        pssx(i,j,:) = -Tpp*reshape(px(i,j,:),1,nx)*pss(i,j)...
                + tmp ...
                + Trr*rx(1,:)*hx*(rx(1,:)*hss) ...
                + Tr*hss'*rxx*hx ...
                + Tr*rx(1,:)*hssx ...
                + Trr*rx(1,:)*hx*rss ...
                + Tr*rssx*hx;
        pssx(i,j,:) = pssx(i,j,:)/Tp;

        % psss: third sigma-derivative; only own third moments vectorMom3(phi)
        % enter (cross third moments are not used).
        tmp = 0;
        for phi1=1:ne
            tmp = tmp + Trrr*rx(1,:)*eta(:,phi1)*rx(1,:)*eta(:,phi1)*rx(1,:)*eta(:,phi1)*vectorMom3(phi1)...
                      + 3*Trr*eta(:,phi1)'*rxx*eta(:,phi1)*rx(1,:)*eta(:,phi1)*vectorMom3(phi1);
        end
        for gama1=1:nx
            for phi1=1:ne
                tmp = tmp + Tr*eta(:,phi1)'*reshape(rxxx(gama1,:,:),nx,nx)*eta(:,phi1)*eta(gama1,phi1)*vectorMom3(phi1);
            end
        end
        tmp = tmp + Tr*rx(1,:)*hsss + Tr*rsss;
        psss(i,j) = tmp/Tp;

        % Updating r: the next horizon uses this period's p as its r.
        rx   = reshape(px(i,j,:),1,nx);
        rxx  = reshape(pxx(i,j,:,:),nx,nx);
        rss  = pss(i,j);
        rxxx = reshape(pxxx(i,j,:,:,:),nx,nx,nx);
        rssx = reshape(pssx(i,j,:),1,nx);
        rsss = psss(i,j);
    end
end

% We reexpress the output if dimy = 1
if dimy == 1
    px      = squeeze(px(dimy,1:k_period,1:nx));
    pxx     = squeeze(pxx(dimy,1:k_period,1:nx,1:nx));
    pss     = reshape(pss(dimy,1:k_period),k_period,1);
    pxxx    = squeeze(pxxx(dimy,1:k_period,1:nx,1:nx,1:nx));
    pssx    = squeeze(pssx(dimy,1:k_period,1:nx));
    psss    = reshape(psss(dimy,1:k_period),k_period,1);
end